Switch

根据条件选择输入张量。该算子根据布尔条件 condition 的值,选择 input_xinput_y 作为输出。该算子不区分数据类型,适用于所有数据类型。

\[\begin{split}\text{output} = \begin{cases} \text{input\_x}, & \text{if } \text{condition} = \text{True} \\ \text{input\_y}, & \text{if } \text{condition} = \text{False} \end{cases}\end{split}\]

该算子不复制数据,只是将输出指针指向选中的输入张量。因此,输出张量共享输入张量的数据指针和元数据。

输入:
  • input_x - 第一个输入张量(TensorC* 类型)。当 condition 为 True 时被选中。

  • input_y - 第二个输入张量(TensorC* 类型)。当 condition 为 False 时被选中。

  • condition - 条件值(bool 类型),决定选择哪个输入张量。

输出:
  • output - 输出张量指针的指针(TensorC** 类型),指向选中的输入张量。

支持平台:

FT78NE MT7004

备注

  • 该算子不区分数据类型,适用于所有数据类型

  • 算子不复制数据,输出张量共享输入张量的数据指针

  • 输出张量的所有元数据(形状、数据类型、格式等)与选中的输入张量相同

共享存储版本:

void switch_s(TensorC *input_x, TensorC *input_y, TensorC **output, bool condition)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <switch.h>
 4
 5int main(int argc, char* argv[]) {
 6    TensorC input_x;
 7    TensorC input_y;
 8    TensorC* output;
 9
10    // 初始化 input_x
11    int x_shape[3] = {2, 3, 4};
12    memcpy(input_x.shape_, x_shape, 3 * sizeof(int));
13    input_x.shape_size_ = 3;
14    input_x.data_type_ = kNumberTypeFloat32;
15    input_x.format_ = Format_NCHW;
16    input_x.data_ = (void *)0xA0000000;
17    input_x.category_ = 0;  // 非常量
18    input_x.shape_changed_ = false;
19
20    // 初始化 input_y
21    int y_shape[3] = {2, 3, 4};
22    memcpy(input_y.shape_, y_shape, 3 * sizeof(int));
23    input_y.shape_size_ = 3;
24    input_y.data_type_ = kNumberTypeFloat32;
25    input_y.format_ = Format_NCHW;
26    input_y.data_ = (void *)0xB0000000;
27    input_y.category_ = 0;
28    input_y.shape_changed_ = false;
29
30    bool condition = true;  // 选择 input_x
31
32    switch_s(&input_x, &input_y, &output, condition);
33
34    return 0;
35}

私有存储版本:

void switch_p(TensorC *input_x, TensorC *input_y, TensorC **output, bool condition)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <switch.h>
 4
 5int main(int argc, char* argv[]) {
 6    TensorC input_x;
 7    TensorC input_y;
 8    TensorC* output;
 9
10    // 初始化 input_x
11    int x_shape[3] = {2, 3, 4};
12    memcpy(input_x.shape_, x_shape, 3 * sizeof(int));
13    input_x.shape_size_ = 3;
14    input_x.data_type_ = kNumberTypeFloat32;
15    input_x.format_ = Format_NCHW;
16    input_x.data_ = (void *)0x10000000;
17    input_x.category_ = 0;  // 非常量
18    input_x.shape_changed_ = false;
19
20    // 初始化 input_y
21    int y_shape[3] = {2, 3, 4};
22    memcpy(input_y.shape_, y_shape, 3 * sizeof(int));
23    input_y.shape_size_ = 3;
24    input_y.data_type_ = kNumberTypeFloat32;
25    input_y.format_ = Format_NCHW;
26    input_y.data_ = (void *)0x10001000;
27    input_y.category_ = 0;
28    input_y.shape_changed_ = false;
29
30    bool condition = true;  // 选择 input_x
31
32    switch_p(&input_x, &input_y, &output, condition);
33
34    return 0;
35}